DIRICHLET_MULTINOM
Overview
The DIRICHLET_MULTINOM function computes statistical properties of the Dirichlet-multinomial distribution, a compound probability distribution that arises when category probabilities are uncertain. Also known as the Dirichlet compound multinomial (DCM) or multivariate Pólya distribution, it models scenarios where observations follow a multinomial distribution with probabilities drawn from a Dirichlet distribution.
This distribution is constructed by first drawing a probability vector \mathbf{p} from a Dirichlet distribution with concentration parameters \boldsymbol{\alpha} = (\alpha_1, \ldots, \alpha_K), then drawing counts from a multinomial distribution with n trials and probability vector \mathbf{p}. The probability mass function is:
P(\mathbf{x} \mid n, \boldsymbol{\alpha}) = \frac{\Gamma(\alpha_0) \Gamma(n+1)}{\Gamma(n + \alpha_0)} \prod_{k=1}^{K} \frac{\Gamma(x_k + \alpha_k)}{\Gamma(\alpha_k) \Gamma(x_k + 1)}
where \alpha_0 = \sum_{k=1}^{K} \alpha_k is the sum of concentration parameters, and \mathbf{x} = (x_1, \ldots, x_K) represents counts in each of K categories with \sum x_k = n.
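The PMF can be evaluated directly from the formula above. The sketch below works in log space with `math.lgamma` so the gamma functions do not overflow for large counts. It is an illustration of the formula, not SciPy's internal implementation, and the helper names `dm_logpmf` and `dm_pmf` are hypothetical.

```python
from math import exp, lgamma


def dm_logpmf(x, alpha):
    """Log-PMF of the Dirichlet-multinomial; n is taken as sum(x)."""
    n = sum(x)
    a0 = sum(alpha)
    # log of Gamma(alpha_0) * Gamma(n + 1) / Gamma(n + alpha_0)
    log_p = lgamma(a0) + lgamma(n + 1) - lgamma(n + a0)
    # product over categories, accumulated as a sum of logs
    for x_k, a_k in zip(x, alpha):
        log_p += lgamma(x_k + a_k) - lgamma(a_k) - lgamma(x_k + 1)
    return log_p


def dm_pmf(x, alpha):
    """PMF obtained by exponentiating the log-PMF."""
    return exp(dm_logpmf(x, alpha))


print(round(dm_pmf([2, 3, 5], [1, 1, 1]), 4))     # 0.0152
print(round(dm_logpmf([2, 3, 5], [1, 1, 1]), 4))  # -4.1897
```

With n = 1 the PMF reduces to \alpha_k / \alpha_0 for the single observed category, which is a quick sanity check on any implementation.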
The expected value for category i is E(X_i) = n \alpha_i / \alpha_0, and the variance is:
\text{Var}(X_i) = n \frac{\alpha_i}{\alpha_0} \left(1 - \frac{\alpha_i}{\alpha_0}\right) \frac{n + \alpha_0}{1 + \alpha_0}
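As a minimal sketch, the closed-form mean and variance translate directly into a few lines of Python. The helper name `dm_mean_var` is hypothetical; with \boldsymbol{\alpha} = (2, 3, 5) and n = 10 it reproduces the values in Examples 3 and 4.

```python
def dm_mean_var(alpha, n):
    """Closed-form mean and variance of each X_i."""
    a0 = sum(alpha)
    inflation = (n + a0) / (1 + a0)  # overdispersion factor
    means = [n * a / a0 for a in alpha]
    variances = [n * (a / a0) * (1 - a / a0) * inflation for a in alpha]
    return means, variances


means, variances = dm_mean_var([2, 3, 5], 10)
print(means)                             # [2.0, 3.0, 5.0]
print([round(v, 4) for v in variances])  # [2.9091, 3.8182, 4.5455]
```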
The distribution is overdispersed relative to the multinomial: every variance is inflated by the factor (n + \alpha_0)/(1 + \alpha_0), which equals 1 when n = 1, exceeds 1 whenever n > 1, and approaches n as \alpha_0 \to 0. This makes it suitable for modeling count data with extra variability, such as word frequencies in documents or allele counts in population genetics. The total concentration \alpha_0 controls the degree of overdispersion: smaller values produce greater variability, while larger values make the distribution approach a standard multinomial with probabilities \alpha_i / \alpha_0.
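The limiting behavior of the inflation factor is easy to see numerically. This sketch, using a hypothetical `inflation_factor` helper, holds n = 10 fixed and increases \alpha_0. The factor depends only on n and \alpha_0, not on how \alpha_0 is split across categories.

```python
def inflation_factor(n, alpha0):
    """Variance inflation of the Dirichlet-multinomial over the multinomial."""
    return (n + alpha0) / (1 + alpha0)


# n = 10 trials; larger total concentration pulls the factor toward 1.
for alpha0 in (1, 10, 100, 10_000):
    print(alpha0, round(inflation_factor(10, alpha0), 4))
# 1 5.5
# 10 1.8182
# 100 1.0891
# 10000 1.0009

# A single trial is never overdispersed.
print(inflation_factor(1, 10))  # 1.0
```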
This implementation uses SciPy’s scipy.stats.dirichlet_multinomial distribution and supports computing the PMF, log-PMF, mean, variance, and covariance matrix. For additional theoretical background, see the Wikipedia article on the Dirichlet-multinomial distribution.
This example function is provided as-is without any representation of accuracy.
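To sanity-check results such as those SciPy returns, one option is a brute-force pass over the whole support: enumerate every count vector that sums to n, weight it by the PMF, and accumulate moments. The sketch below does this in pure Python (it does not call SciPy). The helpers `dm_pmf`, `support`, and `brute_force_moments` are hypothetical, and the enumeration is only practical for small n and K.

```python
from itertools import product
from math import exp, lgamma


def dm_pmf(x, alpha):
    """Dirichlet-multinomial PMF via log-gamma; n is sum(x)."""
    n, a0 = sum(x), sum(alpha)
    log_p = lgamma(a0) + lgamma(n + 1) - lgamma(n + a0)
    for x_k, a_k in zip(x, alpha):
        log_p += lgamma(x_k + a_k) - lgamma(a_k) - lgamma(x_k + 1)
    return exp(log_p)


def support(n, k):
    """All length-k vectors of non-negative integers summing to n."""
    for head in product(range(n + 1), repeat=k - 1):
        if sum(head) <= n:
            yield list(head) + [n - sum(head)]


def brute_force_moments(alpha, n):
    """Total probability, mean vector, and covariance matrix by enumeration."""
    k = len(alpha)
    total = 0.0
    mean = [0.0] * k
    second = [[0.0] * k for _ in range(k)]  # accumulates E[X_i X_j]
    for x in support(n, k):
        p = dm_pmf(x, alpha)
        total += p
        for i in range(k):
            mean[i] += p * x[i]
            for j in range(k):
                second[i][j] += p * x[i] * x[j]
    cov = [[second[i][j] - mean[i] * mean[j] for j in range(k)] for i in range(k)]
    return total, mean, cov


total, mean, cov = brute_force_moments([2, 3, 5], 10)
print(round(total, 6))                          # 1.0
print([round(m, 4) for m in mean])              # [2.0, 3.0, 5.0]
print([round(cov[i][i], 4) for i in range(3)])  # [2.9091, 3.8182, 4.5455]
```

Agreement with the closed-form mean and variance (and with Examples 3 and 4) is a useful regression check when the SciPy version changes.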
Excel Usage
=DIRICHLET_MULTINOM(x, alpha, n, dm_method)
- x (list[list], optional, default: null): 2D list of integer counts for each category. Required for the pmf and logpmf methods.
- alpha (list[list], optional, default: null): 2D list of concentration parameters (positive floats). Each row represents the parameters of one distribution. Although listed as optional, alpha is required by every method.
- n (list[list], optional, default: null): 2D list containing the number of trials for each distribution, one integer per row. Required for all methods except cov, which uses n = 1 when n is omitted.
- dm_method (str, optional, default: "pmf"): Computation method to use: "pmf", "logpmf", "mean", "var", or "cov".
Returns (list[list]): 2D list of results, or error message string.
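The 2D inputs are matched row by row: row i of alpha defines one distribution, row i of n gives its trial count, and row i of x gives the counts to evaluate. The sketch below mimics that pairing with a hypothetical `pair_rows` helper (not part of the function) so the expected shapes are explicit, including the n = 1 fallback for cov.

```python
def pair_rows(alpha, n=None, x=None, dm_method="pmf"):
    """Hypothetical helper: pair row i of alpha with row i of n (and x)."""
    if n is None:
        if dm_method != "cov":
            raise ValueError("n is required unless dm_method is 'cov'")
        n = [[1] for _ in alpha]  # cov falls back to one trial per row
    if len(n) != len(alpha) or any(len(row) != 1 for row in n):
        raise ValueError("n needs one single-integer row per alpha row")
    if x is not None and len(x) != len(alpha):
        raise ValueError("x needs one row per alpha row")
    calls = []
    for i, alpha_row in enumerate(alpha):
        x_row = x[i] if x is not None else None
        calls.append((x_row, alpha_row, int(n[i][0])))
    return calls


# Two distributions in one call: alpha has two rows, so n and x do too.
print(pair_rows([[1, 1, 1], [2, 3, 5]], n=[[10], [10]], x=[[2, 3, 5], [2, 3, 5]]))
# [([2, 3, 5], [1, 1, 1], 10), ([2, 3, 5], [2, 3, 5], 10)]
print(pair_rows([[2, 3, 5]], dm_method="cov"))
# [(None, [2, 3, 5], 1)]
```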
Examples
Example 1: Basic PMF calculation with uniform concentration
Inputs:
| x | alpha | n | dm_method |
|---|---|---|---|
| {2,3,5} | {1,1,1} | 10 | pmf |
Excel formula:
=DIRICHLET_MULTINOM({2,3,5}, {1,1,1}, {10}, "pmf")
Expected output:
| Result |
|---|
| 0.0152 |
Example 2: Log-PMF calculation for same distribution
Inputs:
| x | alpha | n | dm_method |
|---|---|---|---|
| {2,3,5} | {1,1,1} | 10 | logpmf |
Excel formula:
=DIRICHLET_MULTINOM({2,3,5}, {1,1,1}, {10}, "logpmf")
Expected output:
| Result |
|---|
| -4.1897 |
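The values in Examples 1 and 2 can be checked by hand. When every \alpha_k = 1, the PMF simplifies to \Gamma(K)\Gamma(n+1)/\Gamma(n+K) = 1/\binom{n+K-1}{K-1}, so every count vector summing to n is equally likely. This short calculation assumes that uniform-\alpha special case only.

```python
from math import comb, log

n, K = 10, 3
# With all alpha_k = 1, each of the C(n + K - 1, K - 1) count vectors
# summing to n has the same probability.
num_outcomes = comb(n + K - 1, K - 1)
print(num_outcomes)                  # 66
print(round(1 / num_outcomes, 4))    # 0.0152
print(round(-log(num_outcomes), 4))  # -4.1897
```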
Example 3: Expected mean counts for weighted concentration
Inputs:
| alpha | n | dm_method |
|---|---|---|
| {2,3,5} | 10 | mean |
Excel formula:
=DIRICHLET_MULTINOM(, {2,3,5}, {10}, "mean")
Expected output:
| Result | ||
|---|---|---|
| 2 | 3 | 5 |
Example 4: Variance for weighted concentration
Inputs:
| alpha | n | dm_method |
|---|---|---|
| {2,3,5} | 10 | var |
Excel formula:
=DIRICHLET_MULTINOM(, {2,3,5}, {10}, "var")
Expected output:
| Result | ||
|---|---|---|
| 2.9091 | 3.8182 | 4.5455 |
Example 5: Covariance matrix for three categories
Inputs:
| alpha | dm_method |
|---|---|
| {2,3,5} | cov |
Excel formula:
=DIRICHLET_MULTINOM(, {2,3,5}, , "cov")
Expected output:
| Result | ||
|---|---|---|
| 0.16 | -0.06 | -0.1 |
| -0.06 | 0.21 | -0.15 |
| -0.1 | -0.15 | 0.25 |
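Example 5 omits n, so the function falls back to n = 1. The inflation factor is then exactly 1 and the matrix reduces to p_i(1 - p_i) on the diagonal and -p_i p_j off it, with p_i = \alpha_i / \alpha_0. In general the off-diagonal entries are \text{Cov}(X_i, X_j) = -n \frac{\alpha_i \alpha_j}{\alpha_0^2} \frac{n + \alpha_0}{1 + \alpha_0}. The sketch below, with a hypothetical `dm_cov` helper, builds the matrix from that formula.

```python
def dm_cov(alpha, n=1):
    """Closed-form Dirichlet-multinomial covariance matrix."""
    a0 = sum(alpha)
    p = [a / a0 for a in alpha]
    inflation = (n + a0) / (1 + a0)  # exactly 1 when n == 1
    return [
        [n * inflation * (p[i] * (1 - p[i]) if i == j else -p[i] * p[j])
         for j in range(len(p))]
        for i in range(len(p))
    ]


for row in dm_cov([2, 3, 5]):
    print([round(v, 4) for v in row])
# [0.16, -0.06, -0.1]
# [-0.06, 0.21, -0.15]
# [-0.1, -0.15, 0.25]
```

Each row of the matrix sums to zero because the counts are constrained to total n.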
Python Code
```python
from scipy.stats import dirichlet_multinomial as scipy_dirichlet_multinomial


def dirichlet_multinom(x=None, alpha=None, n=None, dm_method='pmf'):
    """
    Computes the probability mass function, log probability mass function, mean, variance, or covariance of the Dirichlet-multinomial distribution.

    See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.dirichlet_multinomial.html

    This example function is provided as-is without any representation of accuracy.

    Args:
        x (list[list], optional): 2D list of integer counts for each category. Required for pmf and logpmf methods. Default is None.
        alpha (list[list], optional): 2D list of concentration parameters (positive floats). Each row represents parameters for one distribution. Always required. Default is None.
        n (list[list], optional): 2D list containing the number of trials for each distribution. Each row contains one integer. Required for all methods except cov, which uses n=1 when omitted. Default is None.
        dm_method (str, optional): Computation method to use. Valid options: 'pmf', 'logpmf', 'mean', 'var', 'cov'. Default is 'pmf'.

    Returns:
        list[list]: 2D list of results, or error message string.
    """
    def to2d(val):
        # Excel passes a single cell as a scalar; wrap it as a 1x1 2D list.
        if val is None:
            return None
        return [[val]] if not isinstance(val, list) else val

    def to_float_list(arr):
        # Convert a SciPy scalar or 1D array result into a flat list of floats.
        if hasattr(arr, 'tolist'):
            arr = arr.tolist()
        if isinstance(arr, (float, int)):
            return [float(arr)]
        return [float(v) for v in arr]

    def parse_n(n_in, n_rows):
        # Validate n as one non-negative integer per alpha row.
        # Returns (list_of_ints, None) on success or (None, error_message).
        n_2d = to2d(n_in)
        if len(n_2d) != n_rows:
            return None, "Error: n must be a 2D list with the same number of rows as alpha."
        if any(not isinstance(row, list) or len(row) != 1 for row in n_2d):
            return None, "Error: Each row of n must contain exactly one integer."
        try:
            values = [int(row[0]) for row in n_2d]
        except (TypeError, ValueError):
            return None, "Error: n must contain integers."
        if any(v < 0 for v in values):
            return None, "Error: n must contain non-negative integers."
        return values, None

    valid_methods = {'pmf', 'logpmf', 'mean', 'var', 'cov'}
    if dm_method not in valid_methods:
        return f"Error: Invalid method '{dm_method}'. Must be one of {sorted(valid_methods)}."

    # alpha: non-empty 2D list of strictly positive floats
    if alpha is None:
        return "Error: Invalid input: alpha is required."
    alpha = to2d(alpha)
    if len(alpha) < 1:
        return "Error: alpha must have at least one row."
    if not all(isinstance(row, list) and len(row) > 0 for row in alpha):
        return "Error: alpha must be a 2D list of positive floats."
    try:
        alpha = [[float(v) for v in row] for row in alpha]
    except (TypeError, ValueError):
        return "Error: alpha must be a 2D list of positive floats."
    if any(any(v <= 0 for v in row) for row in alpha):
        return "Error: alpha must be a 2D list of positive floats."

    # n is required for pmf/logpmf/mean/var; for cov, default to n=1 if omitted
    if n is None:
        if dm_method != 'cov':
            return "Error: Invalid input: n is required."
        n_vals = [1] * len(alpha)
    else:
        n_vals, err = parse_n(n, len(alpha))
        if err:
            return err

    # x: one row of non-negative integer counts per alpha row, summing to n
    if dm_method in {'pmf', 'logpmf'}:
        if x is None:
            return "Error: Invalid input: x is required for pmf/logpmf."
        x = to2d(x)
        if len(x) != len(alpha):
            return "Error: x must be a 2D list with the same number of rows as alpha."
        for x_row, alpha_row in zip(x, alpha):
            if not isinstance(x_row, list) or len(x_row) != len(alpha_row):
                return "Error: Each row of x must have the same length as the corresponding alpha row."
        try:
            x = [[int(v) for v in row] for row in x]
        except (TypeError, ValueError):
            return "Error: x must contain integers."
        if any(any(v < 0 for v in row) for row in x):
            return "Error: x must contain non-negative integers."
        if any(sum(row) != n_val for row, n_val in zip(x, n_vals)):
            return "Error: Invalid input: each row of x must sum to n."

    results = []
    for i, alpha_row in enumerate(alpha):
        try:
            dist = scipy_dirichlet_multinomial(alpha=alpha_row, n=n_vals[i])
            if dm_method == 'pmf':
                results.append(to_float_list(dist.pmf(x[i])))
            elif dm_method == 'logpmf':
                results.append(to_float_list(dist.logpmf(x[i])))
            elif dm_method == 'mean':
                results.append(to_float_list(dist.mean()))
            elif dm_method == 'var':
                results.append(to_float_list(dist.var()))
            else:
                # cov: append each row of the K x K covariance matrix
                cov = dist.cov()
                cov_rows = cov.tolist() if hasattr(cov, 'tolist') else cov
                for cov_row in cov_rows:
                    results.append([float(v) for v in cov_row])
        except Exception as e:
            return f"Error: computing {dm_method}: {e}"
    return results
```